Add support for linear-time MMD estimator. #475
base: master
Conversation
Do we think users are ever going to have equal reference and test batch sizes in practice? I'd guess the reference set is almost always going to be much larger. I wonder if we'd be better off using the B-statistic estimator by default for the linear case rather than Gretton's estimator for equal sample sizes. It additionally has the advantage of a tunable parameter that allows interpolation between a linear- and a quadratic-time estimator. Edit: it's actually not quite this simple, but I think we should put some thought into how best to address the n != m case.
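For context, the B-statistic estimator referenced above computes the full unbiased MMD^2 on disjoint blocks and averages the results; the block size is the tunable parameter that interpolates between linear and quadratic cost. A minimal sketch, assuming a kernel that returns the full pairwise kernel matrix (the function name and blocking scheme are illustrative, not part of this PR):

import torch

def block_mmd2(x: torch.Tensor, y: torch.Tensor, kernel, block_size: int) -> torch.Tensor:
    # Average the unbiased quadratic-time MMD^2 estimate over disjoint blocks.
    # Cost is O(n * block_size): a small block size is close to linear time,
    # while block_size = n recovers the quadratic-time estimator.
    n = min(x.shape[0], y.shape[0])
    estimates = []
    for i in range(0, n - block_size + 1, block_size):
        xb, yb = x[i:i + block_size], y[i:i + block_size]
        k_xx, k_yy, k_xy = kernel(xb, xb), kernel(yb, yb), kernel(xb, yb)
        b = block_size
        # unbiased within-block estimate: drop the diagonal self-similarities
        mmd2_b = ((k_xx.sum() - k_xx.trace()) / (b * (b - 1))
                  + (k_yy.sum() - k_yy.trace()) / (b * (b - 1))
                  - 2. * k_xy.mean())
        estimates.append(mmd2_b)
    return torch.stack(estimates).mean()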
Agreed; maybe we can leave the current PR as it is for the linear-time estimator and do a separate one for the additional B-statistic implementation.
…linear_time_mmd (merge conflicts: alibi_detect/cd/pytorch/mmd.py)
…eshold with the linear-time estimator, instead of permutation.
Now the linear-time estimator also uses the Gaussian distribution under the null for the test threshold, so no permutation is required. It should be the fastest option, at the cost of lower test power and some unused samples.
def forward(self, x: Union[np.ndarray, torch.Tensor],
            y: Union[np.ndarray, torch.Tensor],
            infer_sigma: bool = False,
            diag: bool = False) -> torch.Tensor:
Given that they refer to the same thing, perhaps we could keep consistency between this kwarg name and the naming convention adopted for the squared distance functions? So perhaps pairwise: bool = True?
Indeed, now uses pairwise as suggested.
if infer_sigma or self.init_required:
    if self.trainable and infer_sigma:
        raise ValueError("Gradients cannot be computed w.r.t. an inferred sigma value")
    sigma = self.init_sigma_fn(x, y, dist)
    if not diag:
Is this a good default behaviour to have? Could we end up with O(n^2) costs in places where the linear-time estimator is being used specifically because such a cost would be infeasible?
It now directly uses the median of the non-pairwise distances.
@@ -69,15 +69,24 @@ def __init__(
    def sigma(self) -> tf.Tensor:
        return tf.math.exp(self.log_sigma)

    def call(self, x: tf.Tensor, y: tf.Tensor, infer_sigma: bool = False) -> tf.Tensor:
    def call(self, x: tf.Tensor, y: tf.Tensor,
See the comments on the PyTorch version.
@@ -93,7 +115,43 @@ def batch_compute_kernel_matrix(
    return k_mat


def mmd2_from_kernel_matrix(kernel_mat: torch.Tensor, m: int, permute: bool = False,
def linear_mmd2(x: torch.Tensor,
Nitpick, but it's probably worth keeping the indentation within function definitions consistent with all of the other functions.
Fixed; it should be consistent across the board now.
alibi_detect/cd/mmd.py (outdated)
@@ -18,6 +18,7 @@ def __init__(
    x_ref: Union[np.ndarray, list],
    backend: str = 'tensorflow',
    p_val: float = .05,
    estimator: str = 'quad',
Would estimator_complexity be more descriptive? (Or at least make it clear in the docstring.)
Added extra description in the docstring.
k_yz = kernel(x=y[0::2, :], y=x[1::2, :], diag=True)

h = k_xx + k_yy - k_xy - k_yz
mmd2 = h.sum() / (n / 2.)
Is there a reason we don't just use h.mean() and h.var()?
Now uses h.mean(), and torch.var(..., unbiased=True) in the torch version. The TF version uses tf.reduce_mean and a manual bias correction.
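To make the estimator under discussion concrete, here is a minimal sketch of the linear-time MMD^2 of Gretton et al. (2012, Lemma 14), assuming a kernel that returns element-wise (non-pairwise) values k(a_i, b_i); it sketches the idea rather than the PR's exact code:

import torch

def linear_mmd2_sketch(x: torch.Tensor, y: torch.Tensor, kernel):
    # Pair up consecutive samples so only O(n) kernel evaluations are needed;
    # an odd trailing sample is left unused.
    n = x.shape[0] - x.shape[0] % 2
    h = (kernel(x[0:n:2], x[1:n:2]) + kernel(y[0:n:2], y[1:n:2])
         - kernel(x[0:n:2], y[1:n:2]) - kernel(y[0:n:2], x[1:n:2]))
    return h.mean(), torch.var(h, unbiased=True)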
def linear_mmd2(x: tf.Tensor,
                y: tf.Tensor,
                kernel: Callable,
                permute: bool = False) -> Tuple[tf.Tensor, tf.Tensor]:
Is there a reason we offer the permute option for TensorFlow but not torch?
Legacy issue; it's now removed from the TensorFlow version.
k_xx = kernel(x_hat[0::2, :], x_hat[1::2, :], diag=True)
k_yy = kernel(y_hat[0::2, :], y_hat[1::2, :], diag=True)
k_xy = kernel(x_hat[0::2, :], y_hat[1::2, :], diag=True)
k_yz = kernel(y_hat[0::2, :], x_hat[1::2, :], diag=True)
Seems like unnecessary duplication
Removed.
alibi_detect/cd/pytorch/mmd.py (outdated)
mmd2 = mmd2.numpy().item()
var_mmd2 = var_mmd2.numpy().item()
std_mmd2 = np.sqrt(var_mmd2)
p_val = 1 - stats.norm.cdf(mmd2 * np.sqrt(n_hat), loc=0., scale=std_mmd2 * np.sqrt(2))
Nitpick, but should this be a t-test?
Nice spot; now fixed with a t-test in both versions.
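For reference, a sketch of what the t-test version could look like, keeping the sqrt(n) scaling and the factor-2 asymptotic variance from the quoted line; the helper name and the n_hat - 1 degrees of freedom are assumptions here, not necessarily the PR's final code:

import numpy as np
from scipy import stats

def linear_mmd_p_value(mmd2: float, var_mmd2: float, n_hat: int) -> float:
    # Under H0, sqrt(n) * MMD_l^2 is asymptotically N(0, 2 * sigma_h^2); since
    # the variance is estimated from the same n_hat pairs, a Student-t null is
    # the safer small-sample choice than the normal used above.
    t_stat = mmd2 * np.sqrt(n_hat) / np.sqrt(2. * var_mmd2)
    return float(1 - stats.t.cdf(t_stat, df=n_hat - 1))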
mmd2 = mmd2.cpu()
mmd2 = mmd2.numpy().item()
var_mmd2 = var_mmd2.numpy().item()
std_mmd2 = np.sqrt(var_mmd2)
Could we directly use torch.std(...) in linear_mmd2? That would remove the few additional lines of code here.
The new version uses np.sqrt(np.clip(var_mmd2, 1e-8, 1e-8)) for numeric stability.
@@ -30,6 +30,28 @@ def squared_pairwise_distance(x: torch.Tensor, y: torch.Tensor, a_min: float = 1
    return dist.clamp_min_(a_min)


def squared_distance(x: torch.Tensor, y: torch.Tensor, a_min: float = 1e-30) -> torch.Tensor:
Can we just apply a reduction to squared_pairwise_distance instead of using an extra function?
Now implemented as a single function.
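One possible shape for the merged function, purely illustrative (the library's actual implementation computes the distances via dot products rather than torch.cdist):

import torch

def squared_pairwise_distance(x: torch.Tensor, y: torch.Tensor,
                              a_min: float = 1e-30,
                              pairwise: bool = True) -> torch.Tensor:
    # pairwise=True: the [Nx, Ny] matrix of squared distances.
    # pairwise=False: element-wise squared distances ||x_i - y_i||^2, shape [N].
    if pairwise:
        dist = torch.cdist(x, y) ** 2
    else:
        dist = ((x - y) ** 2).sum(dim=-1)
    return dist.clamp_min(a_min)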
m = np.shape(y)[0]
if n != m:
    raise RuntimeError("Linear-time estimator requires equal size samples")
k_xx = kernel(x=x[0::2, :], y=x[1::2, :], pairwise=False)
We should be able to do this at init time (so self.k_xx becomes useful again), saving compute at prediction time.
Fixed; the kernel matrix is now reused at prediction time.
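A sketch of the init-time caching being discussed, with hypothetical class and method names; only the kernel terms involving the new batch y are evaluated per prediction:

import torch

class LinearTimeMMDSketch:
    def __init__(self, x_ref: torch.Tensor, kernel):
        self.x_ref = x_ref
        self.kernel = kernel
        # k(x_{2i-1}, x_{2i}) depends only on the reference data, so it can
        # be computed once here rather than on every call to score().
        self.k_xx = kernel(self.x_ref[0::2], self.x_ref[1::2], pairwise=False)

    def score(self, y: torch.Tensor) -> torch.Tensor:
        # assumes len(y) == len(self.x_ref), as the linear estimator requires
        k_yy = self.kernel(y[0::2], y[1::2], pairwise=False)
        k_xy = self.kernel(self.x_ref[0::2], y[1::2], pairwise=False)
        k_yx = self.kernel(y[0::2], self.x_ref[1::2], pairwise=False)
        h = self.k_xx + k_yy - k_xy - k_yx
        return h.mean()

Note the sketch assumes a fixed kernel; if the bandwidth were inferred from each test batch, the cached reference term would need recomputing.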
""" | ||
n = np.shape(x)[0] | ||
m = np.shape(y)[0] | ||
if n != m: |
In my opinion this behaviour should already be checked beforehand (see the comment in the method itself).
Fixed.
k_xx = kernel(x=x[0::2, :], y=x[1::2, :], pairwise=False)
k_yy = kernel(x=y[0::2, :], y=y[1::2, :], pairwise=False)
k_xy = kernel(x=x[0::2, :], y=y[1::2, :], pairwise=False)
k_yz = kernel(x=y[0::2, :], y=x[1::2, :], pairwise=False)
Is k_yz the paper's notation? Because it might be easier to follow by just calling it k_yx.
Typo; thanks for noticing, fixed.
@@ -68,16 +68,24 @@ def __init__(
    def sigma(self) -> torch.Tensor:
        return self.log_sigma.exp()

    def forward(self, x: Union[np.ndarray, torch.Tensor], y: Union[np.ndarray, torch.Tensor],
                infer_sigma: bool = False) -> torch.Tensor:
    def forward(self, x: Union[np.ndarray, torch.Tensor],
Nitpicking big time here, but let's try to keep the same type of indentation as e.g. in DeepKernel below.
Fixed.
x, y = torch.as_tensor(x), torch.as_tensor(y)
dist = distance.squared_pairwise_distance(x.flatten(1), y.flatten(1))  # [Nx, Ny]
if pairwise:
Check my comment in distance.py, which might make this if/else redundant and reduce it to a kwarg of the distance function.
Fixed; it's now handled via the squared_pairwise_distance function argument.
if pairwise:
    sigma = self.init_sigma_fn(x, y, dist)
else:
    sigma = (.5 * dist.flatten().sort().values[dist.shape[0] // 2 - 1].unsqueeze(dim=-1)) ** .5
Again, I think we can avoid hard-coding this behaviour and fall back on self.init_sigma_fn, but with the desired linear-detector behaviour.
Slightly tricky, as the default init_sigma_fn is used by other detectors. It might be easier to keep the additional line here?
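If the fallback route were ever preferred, a linear-variant init_sigma_fn could simply wrap the hard-coded expression quoted above (hypothetical name; same median heuristic, choosing sigma so that 2 * sigma^2 equals the median element-wise squared distance):

import torch

def sigma_median_linear(x: torch.Tensor, y: torch.Tensor,
                        dist: torch.Tensor) -> torch.Tensor:
    # dist holds the element-wise (non-pairwise) squared distances, shape [N]
    n_median = dist.shape[0] // 2 - 1
    sigma2 = .5 * dist.flatten().sort().values[n_median]
    return sigma2.sqrt().unsqueeze(dim=-1)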
Left a number of comments related to the PyTorch implementation. Let's work through those first and then we can apply the desired changes to TensorFlow as well.
"Requesting changes" to ensure we do not merge until #489 has been merged and predict
updated in this PR.
@@ -20,14 +26,44 @@ def squared_pairwise_distance(x: tf.Tensor, y: tf.Tensor, a_min: float = 1e-30,
        Lower bound to clip distance values.
    a_max
        Upper bound to clip distance values.
    pairwise
Isn't it a bit unclear to have a function named squared_pairwise_distance that optionally computes non-pairwise distances? Perhaps squared_pairwise_distance should be renamed squared_distance? Or should the pairwise=False functionality be separated out into its own distance function?
I was thinking the same. It was previously a separate function, and @arnaudvl suggested keeping the repeated parts minimal. I guess changing the function name across all related methods would be preferable.
The TF version is also fixed for all the points above answered with "fixed".
…IO#489). Merge branch 'master' into linear_time_mmd (conflicts: alibi_detect/cd/base.py, alibi_detect/cd/mmd.py, alibi_detect/cd/pytorch/mmd.py, alibi_detect/cd/tensorflow/mmd.py)
@Srceh I have now merged in the v0.10.0 related changes from |
if self.device.type == 'cuda':
    mmd2, mmd2_permuted = mmd2.cpu(), mmd2_permuted.cpu()
p_val = (mmd2 <= mmd2_permuted).float().mean()
# compute distance threshold
idx_threshold = int(self.p_val * len(mmd2_permuted))
distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy()


class LinearTimeMMDDriftTorch(BaseMMDDrift):
Since these new subclasses don't make use of self.n_permutations (set in BaseMMDDrift), shall we set it to None? I had a moment of confusion when updating the tests, since self.n_permutations == 100 when estimator == 'linear'.
Good point. The default number of permutations can then be initialised in /cd/mmd.py when the estimator is 'quad'.
    self._detector = MMDDriftTF(*args, **kwargs)  # type: ignore
elif estimator == 'linear':
    kwargs.pop('n_permutations', None)
    self._detector = LinearTimeMMDDriftTF(*args, **kwargs)  # type: ignore
Since the logic to set self._detector is located here, we should add additional tests to alibi_detect/cd/tests/test_mmd.py to check that the correct subclass is selected conditional on backend and estimator.
Indeed, will modify the tests.
Simply rewriting the test to loop over the different backend and estimator options should do the job.
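A sketch of such a test; the estimator kwarg and the LinearTime* subclass names come from this PR, so the assertion reflects assumed final naming rather than the released API:

import numpy as np
import pytest
from alibi_detect.cd import MMDDrift

@pytest.mark.parametrize('backend', ['pytorch', 'tensorflow'])
@pytest.mark.parametrize('estimator', ['quad', 'linear'])
def test_mmd_dispatch(backend, estimator):
    x_ref = np.random.randn(100, 5).astype(np.float32)
    cd = MMDDrift(x_ref, backend=backend, estimator=estimator)
    # the backend/estimator combination should select the matching subclass
    name = type(cd._detector).__name__
    assert ('LinearTime' in name) == (estimator == 'linear')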
This PR implements the linear-time estimator (Lemma 14 in the paper), as requested in #288.